# =============================================================================
# Statistical Analysis Code for:
# "Chronic adaptive versus conventional DBS response patterns in Parkinson's
#  disease: A pilot randomized crossover trial"
#
# Bayesian Analysis: Mixed-Effects ANCOVA for aDBS vs cDBS Crossover Trial
# Author: Jun Tanimura, MD, MSc.
# Created: 2025-11-06
# Last Updated: 2025-12-06
# Version: v2.1
#
# This code performs Bayesian analysis for ON time and UPDRS Part III
# using the same data structure and formula as the main ANCOVA analysis
# (02_ANCOVA_v3.1_FULL.R)
#
# Key features:
# - Uses brms package (Stan for MCMC sampling)
# - Two prior specifications for the treatment coefficient: weakly informative
#   and literature-informed
# - Creates posterior density plots showing probability distributions
# - Convergence diagnostics: Rhat, effective sample size, divergent transitions
#
# The Bayesian GLMM method is based on:
# Baba, S. (2019). Introduction to Data Analysis by Bayesian Statistical 
#   Modeling using R and Stan. Kodansha.
# =============================================================================

# Load required packages
library(tidyverse)
library(brms)
library(ggplot2)
library(bayesplot)
library(posterior)
library(patchwork)

# ---- Input data (identical source file to the main ANCOVA analysis) ----
df <- read_csv(file = "data.csv")

# Outcome stems in reporting order. For each stem X the reshaping step below
# looks for the columns P_X (baseline), C_X (cDBS) and A_X (aDBS) in `df`.
eval_items_ordered <- c(
  "ADL",
  "UPDRS_part_1", "UPDRS_part_2", "UPDRS_part_3", "UPDRS_part_4",
  "UPDRS_total",
  "Mean_corr_ON_time", "Mean_corr_OFF_time",
  "Mean_corr_Dys_sev_time", "Mean_corr_Dys_weak_time",
  "MMSE", "PDQ-39"
)

# Reshape to long format (same as main analysis): one row per patient x
# post-treatment condition x outcome. Each row carries the baseline (P_) value
# next to the post-treatment score (C_ = cDBS, A_ = aDBS). Outcomes whose three
# columns are not all present in `df` yield NULL and are dropped by map_dfr().
# Outcomes whose stem contains "time" are log1p-transformed after clipping at 0.
ancova_data <- map_dfr(eval_items_ordered, function(item) {
  wide_cols <- paste0(c("P_", "C_", "A_"), item)
  if (!all(wide_cols %in% names(df))) {
    return(NULL)
  }
  baseline_col <- wide_cols[1]
  post_cols <- wide_cols[2:3]
  log_scale <- str_detect(item, "time")

  df %>%
    select(PatientID, Group, all_of(wide_cols)) %>%
    pivot_longer(
      cols = all_of(post_cols),
      names_to = "Treatment_raw",
      values_to = "Score_post"
    ) %>%
    mutate(
      Baseline = .data[[baseline_col]],
      Treatment = case_when(
        startsWith(Treatment_raw, "C_") ~ "cDBS",
        startsWith(Treatment_raw, "A_") ~ "aDBS"
      ),
      # Crossover period: Group 1 receives cDBS first, Group 2 receives aDBS first
      Period = case_when(
        Group == 1 & Treatment == "cDBS" ~ "Phase1",
        Group == 1 & Treatment == "aDBS" ~ "Phase2",
        Group == 2 & Treatment == "aDBS" ~ "Phase1",
        Group == 2 & Treatment == "cDBS" ~ "Phase2"
      ),
      Evaluation = item,
      # Row-length flag so ifelse() returns one value per row
      is_time_var = log_scale,
      Score_transformed = ifelse(is_time_var, log1p(pmax(Score_post, 0)), Score_post),
      Baseline_transformed = ifelse(is_time_var, log1p(pmax(Baseline, 0)), Baseline)
    ) %>%
    select(PatientID, Group, Evaluation, Treatment, Period,
           Baseline, Score_post, Baseline_transformed, Score_transformed)
}) %>%
  filter(!is.na(Score_post)) %>%
  mutate(
    Treatment = factor(Treatment, levels = c("cDBS", "aDBS")),  # cDBS = reference level
    Period = factor(Period, levels = c("Phase1", "Phase2")),
    Evaluation = factor(Evaluation, levels = eval_items_ordered),
    PatientID = factor(PatientID)
  )

# z-score the (possibly log1p-transformed) baseline within each outcome, then add
# the randomisation sequence. Note: scaling runs over the long-format rows, so
# each patient's baseline is counted once per retained post-treatment row.
ancova_data <- ancova_data %>%
  group_by(Evaluation) %>%
  mutate(Baseline_z = c(scale(Baseline_transformed))) %>%
  ungroup() %>%
  # Sequence: Group 1 = cDBS -> aDBS, Group 2 = aDBS -> cDBS
  mutate(Sequence = factor(Group))

print("=== Data transformation completed ===")
print(paste0("Total observations: ", nrow(ancova_data)))
print(paste0("Unique evaluations: ", n_distinct(ancova_data$Evaluation)))

# =============================================================================
# Define target outcomes and MCID values
# =============================================================================

target_outcomes <- c("Mean_corr_ON_time", "UPDRS_part_3")

# Minimal clinically important differences (MCIDs), both on the ORIGINAL
# clinical scale:
# - ON time: +2 hours. The model is fitted on the log1p scale, so
#   run_bayesian_analysis() back-transforms the treatment-specific marginal
#   means to hours and evaluates P(Δ > +2h) / P(Δ < -2h) on that difference.
# - UPDRS Part III: -5 points (lower is better). The model is fitted on the
#   original scale, so the threshold applies directly to the treatment
#   coefficient.
mcid_values <- list(
  "Mean_corr_ON_time" = 2.0,    # +2 hours (original scale)
  "UPDRS_part_3" = -5.0         # -5 points (lower is better)
)

# =============================================================================
# Function to get prior specifications
# =============================================================================

# Return the Normal prior (mu, sd) for the aDBS - cDBS treatment coefficient.
#
# Args:
#   outcome:      "Mean_corr_ON_time" or "UPDRS_part_3".
#   prior_type:   "weak" or "literature".
#   outcome_data: long-format rows for this outcome (needs Treatment and
#                 Score_post). Required for ON time only.
# Returns:
#   list(mu, sd) on the modelling scale. For ON time the list also holds
#   mu_time / sd_time (the hours-scale prior) and representative_time.
# Errors: unknown prior_type or outcome, missing outcome_data for ON time, or
#   no finite ON-time values to anchor the scale conversion.
get_prior_spec <- function(outcome, prior_type, outcome_data = NULL) {
  valid_prior_types <- c("weak", "literature")
  if (length(prior_type) != 1L || !(prior_type %in% valid_prior_types)) {
    stop(sprintf("Unknown prior type '%s'; expected one of: %s",
                 paste(prior_type, collapse = ", "),
                 paste(valid_prior_types, collapse = ", ")),
         call. = FALSE)
  }

  if (outcome == "Mean_corr_ON_time") {
    # The prior is elicited in hours but the model is fitted on log1p(hours),
    # so it is converted around a representative ON time.
    if (is.null(outcome_data)) {
      stop("outcome_data is required for ON time prior conversion")
    }

    # Representative ON time (hours, original scale): mean under cDBS, falling
    # back to the mean over all rows when no finite cDBS mean exists.
    cdbs_rows <- which(outcome_data$Treatment == "cDBS")
    representative_time <- mean(outcome_data$Score_post[cdbs_rows], na.rm = TRUE)
    if (!is.finite(representative_time)) {
      representative_time <- mean(outcome_data$Score_post, na.rm = TRUE)
    }
    # Without any observed ON time the prior cannot be anchored. Fail loudly
    # rather than passing normal(NaN, NaN) to Stan (max(NaN, 1) is NaN).
    if (!is.finite(representative_time)) {
      stop("No finite ON time values available to anchor the prior conversion",
           call. = FALSE)
    }

    # Floor at 1 hour so the conversion is not anchored near zero
    representative_time <- max(representative_time, 1.0)

    cat(sprintf("Representative ON time for prior conversion: %.2f hours\n", representative_time))

    if (prior_type == "weak") {
      # Weakly informative: Normal(0, 4) hours
      mu_time <- 0
      sd_time <- 4.0
    } else {
      # Literature-informed: Normal(1.5, 3) hours
      mu_time <- 1.5
      sd_time <- 3.0
    }

    # Mean: exact shift on the log1p scale at the representative time
    mu_log1p <- log1p(representative_time + mu_time) - log1p(representative_time)

    # SD: first-order (delta-method) approximation, d/dx log1p(x) = 1 / (1 + x),
    # evaluated at the representative time
    gradient_at_rep <- 1 / (1 + representative_time)
    sd_log1p <- sd_time * gradient_at_rep

    cat(sprintf("Prior on time scale: Normal(%.2f, %.2f) hours\n", mu_time, sd_time))
    cat(sprintf("Prior on log1p scale: Normal(%.4f, %.4f)\n", mu_log1p, sd_log1p))

    return(list(mu = mu_log1p, sd = sd_log1p,
                mu_time = mu_time, sd_time = sd_time,
                representative_time = representative_time))
  }

  if (outcome == "UPDRS_part_3") {
    # UPDRS3 is modelled on its original scale; no conversion needed
    if (prior_type == "weak") {
      return(list(mu = 0, sd = 6))
    }
    return(list(mu = -3, sd = 4))
  }

  stop(sprintf("Unknown outcome '%s'", outcome), call. = FALSE)
}

# =============================================================================
# Function to run Bayesian analysis for a single outcome and prior
# =============================================================================

# Gather the per-parameter diagnostic tables of a brms summary: population-level
# effects ($fixed), family parameters such as sigma ($spec_pars) and one table
# per grouping factor ($random, a named list). NULL or empty tables are dropped.
.summary_diagnostic_tables <- function(summary_fit) {
  tables <- c(list(summary_fit$fixed, summary_fit$spec_pars), summary_fit$random)
  Filter(function(tbl) is.data.frame(tbl) && nrow(tbl) > 0, tables)
}

# Pull one diagnostic column (e.g. "Rhat", "Bulk_ESS") from every table into one
# numeric vector; NULL when no table has that column.
.diagnostic_column <- function(tables, column) {
  unlist(lapply(tables, function(tbl) tbl[[column]]), use.names = FALSE)
}

# Count post-warmup divergent transitions of a brms fit. Returns NA_integer_
# (never a silent 0) when the count cannot be determined.
.count_divergent_transitions <- function(fit) {
  n_div <- tryCatch({
    # nuts_params() returns long format: Chain, Iteration, Parameter, Value
    np <- nuts_params(fit, pars = "divergent__")
    if (nrow(np) > 0) as.integer(sum(np$Value, na.rm = TRUE)) else NA_integer_
  }, error = function(e) NA_integer_)

  # Fallback: read the sampler parameters from the underlying stanfit
  if (is.na(n_div) && requireNamespace("rstan", quietly = TRUE)) {
    n_div <- tryCatch({
      sampler_params <- rstan::get_sampler_params(fit$fit, inc_warmup = FALSE)
      as.integer(sum(vapply(sampler_params,
                            function(x) sum(x[, "divergent__"], na.rm = TRUE),
                            numeric(1))))
    }, error = function(e) NA_integer_)
  }
  n_div
}

# Print convergence diagnostics for a brms fit and warn when a target is missed
# (Rhat < 1.01, bulk/tail ESS > 400, zero divergences). Rhat and ESS cover fixed
# effects, family parameters (sigma) and all group-level terms.
# Returns list(max_rhat, min_n_eff, min_tail_ess, n_divergent).
.check_convergence <- function(fit) {
  cat("\n=== Convergence Diagnostics ===\n")
  tables <- .summary_diagnostic_tables(summary(fit))

  max_rhat <- max(.diagnostic_column(tables, "Rhat"), na.rm = TRUE)
  cat(sprintf("Max Rhat (fixed/random/sigma): %.4f (target: < 1.01)\n", max_rhat))
  if (max_rhat > 1.01) {
    warning(sprintf("High Rhat detected: %.4f. Consider increasing iterations or checking model.\n", max_rhat),
            call. = FALSE)
  }

  # Effective sample size in the bulk of the posterior
  min_n_eff <- min(.diagnostic_column(tables, "Bulk_ESS"), na.rm = TRUE)
  cat(sprintf("Min Bulk ESS: %.0f (target: > 400)\n", min_n_eff))
  if (min_n_eff < 400) {
    warning(sprintf("Low Bulk ESS detected: %.0f. Consider increasing iterations.\n", min_n_eff),
            call. = FALSE)
  }

  # Effective sample size in the tails (relevant for credible intervals)
  tail_ess_values <- .diagnostic_column(tables, "Tail_ESS")
  min_tail_ess <- NA_real_
  if (length(tail_ess_values) > 0) {
    min_tail_ess <- min(tail_ess_values, na.rm = TRUE)
    cat(sprintf("Min Tail ESS: %.0f (target: > 400)\n", min_tail_ess))
    if (min_tail_ess < 400) {
      warning(sprintf("Low Tail ESS detected (tail): %.0f. Consider increasing iterations.\n", min_tail_ess),
              call. = FALSE)
    }
  } else {
    cat("Tail ESS not available in summary; skipping Tail ESS check.\n")
  }

  n_divergent <- .count_divergent_transitions(fit)
  if (!is.na(n_divergent)) {
    cat(sprintf("Divergent transitions: %d (target: 0)\n", n_divergent))
    if (n_divergent > 0) {
      warning(sprintf("Divergent transitions detected: %d. Consider increasing adapt_delta (e.g., 0.99).\n", n_divergent),
              call. = FALSE)
    }
  } else {
    # Keep NA ("unknown"), never report 0 when the count was not obtained
    cat("Divergent transitions: could not be determined automatically. Check brms warnings above.\n")
  }

  list(max_rhat = max_rhat, min_n_eff = min_n_eff,
       min_tail_ess = min_tail_ess, n_divergent = n_divergent)
}

# Posterior draws of the aDBS - cDBS ON-time difference on the original scale
# (hours). Uses population-level predictions (re_formula = NA) at the mean
# baseline (Baseline_z = 0) for every observed Period x Sequence cell. These are
# averaged on the log1p scale per draw, then back-transformed with expm1().
# Returns list(diff, mean_aDBS, mean_cDBS), each a vector with one value per draw.
.on_time_contrast_original <- function(fit, outcome_data) {
  grid <- outcome_data %>%
    distinct(Period, Sequence) %>%
    mutate(Baseline_z = 0)

  treatment_mean_hours <- function(treatment) {
    # summary = FALSE -> matrix of draws (rows) x grid cells (columns)
    draws <- fitted(fit, newdata = mutate(grid, Treatment = treatment),
                    summary = FALSE, re_formula = NA)
    expm1(rowMeans(draws))
  }

  mean_aDBS <- treatment_mean_hours("aDBS")
  mean_cDBS <- treatment_mean_hours("cDBS")
  list(diff = mean_aDBS - mean_cDBS, mean_aDBS = mean_aDBS, mean_cDBS = mean_cDBS)
}

# Fit the Bayesian mixed-effects ANCOVA for one outcome / prior combination and
# summarise the posterior of the aDBS - cDBS treatment effect.
#
# Args:
#   outcome:    "Mean_corr_ON_time" or "UPDRS_part_3".
#   prior_type: "weak" or "literature" (passed to get_prior_spec()).
#   data:       long-format ANCOVA data (all outcomes).
#   chains, iter, warmup, cores, adapt_delta, seed: sampler settings forwarded
#     to brms::brm(). The defaults reproduce the original analysis.
# Returns: a list of the fit, treatment-effect draws and summaries, direction and
#   MCID probabilities, and convergence diagnostics. For ON time it also holds the
#   original-scale contrast draws and marginal means. Returns NULL when fewer than
#   10 complete rows exist or when model fitting fails.
run_bayesian_analysis <- function(outcome, prior_type, data,
                                  chains = 4, iter = 4000, warmup = 2000,
                                  cores = 4, adapt_delta = 0.95,
                                  seed = 12345) {

  cat(sprintf("\n=== Bayesian Analysis: %s with %s prior ===\n", outcome, prior_type))

  # Complete cases for this outcome (transformed response and z-scored baseline)
  outcome_data <- data %>%
    filter(Evaluation == outcome) %>%
    filter(!is.na(Score_transformed) & !is.na(Baseline_z))

  if (nrow(outcome_data) < 10) {
    cat("Insufficient data for analysis\n")
    return(NULL)
  }

  cat(sprintf("N observations: %d\n", nrow(outcome_data)))
  cat(sprintf("N patients: %d\n", length(unique(outcome_data$PatientID))))

  # Prior on the treatment coefficient. cDBS is the reference level, so the
  # aDBS effect is b_TreatmentaDBS. For ON time, get_prior_spec() converts the
  # hours-scale prior to the log1p scale. Other coefficients keep brms defaults.
  prior_spec <- get_prior_spec(outcome, prior_type, outcome_data)
  prior_mu_val <- as.numeric(prior_spec$mu)
  prior_sd_val <- as.numeric(prior_spec$sd)
  treatment_prior <- set_prior(
    paste0("normal(", prior_mu_val, ", ", prior_sd_val, ")"),
    class = "b",
    coef = "TreatmentaDBS"
  )

  # Same fixed effects as the main ANCOVA, plus a random intercept per patient.
  # For ON time, Score_transformed is on the log1p scale.
  formula_bayes <- Score_transformed ~ Treatment + Period + Sequence + Baseline_z + (1|PatientID)

  cat("Fitting Bayesian model...\n")
  if (outcome == "Mean_corr_ON_time") {
    cat(sprintf("Prior for Treatment effect on log1p scale: Normal(%.4f, %.4f)\n",
                prior_spec$mu, prior_spec$sd))
    if (!is.null(prior_spec$mu_time)) {
      cat(sprintf("  (corresponds to Normal(%.2f, %.2f) hours on original time scale)\n",
                  prior_spec$mu_time, prior_spec$sd_time))
    }
  } else {
    cat(sprintf("Prior for Treatment effect: Normal(%.2f, %.2f)\n", prior_spec$mu, prior_spec$sd))
  }

  fit_bayes <- tryCatch(
    brm(
      formula = formula_bayes,
      data = outcome_data,
      family = gaussian(),
      prior = treatment_prior,
      chains = chains,
      iter = iter,
      warmup = warmup,
      cores = cores,
      control = list(adapt_delta = adapt_delta),
      seed = seed,
      save_pars = save_pars(all = TRUE)
    ),
    error = function(e) {
      cat(sprintf("Model fitting failed: %s\n", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(fit_bayes)) {
    return(NULL)
  }

  diagnostics <- .check_convergence(fit_bayes)

  # Posterior of the treatment coefficient (modelling / transformed scale)
  treatment_samples <- as_draws_df(fit_bayes)$b_TreatmentaDBS
  posterior_mean <- mean(treatment_samples)
  posterior_sd   <- sd(treatment_samples)
  posterior_ci   <- quantile(treatment_samples, probs = c(0.025, 0.975))
  posterior_ci_lower <- posterior_ci[1]
  posterior_ci_upper <- posterior_ci[2]

  p_direction_positive <- mean(treatment_samples > 0)
  p_direction_negative <- mean(treatment_samples < 0)
  mcid <- mcid_values[[outcome]]

  if (outcome == "Mean_corr_ON_time") {
    # MCID (hours) is evaluated on the back-transformed marginal-mean contrast
    on_time <- .on_time_contrast_original(fit_bayes, outcome_data)
    diff_original <- on_time$diff
    p_direction <- p_direction_positive  # kept for backward compatibility
    p_mcid_positive <- mean(diff_original > mcid)    # P(Δ > +MCID)
    p_mcid_negative <- mean(diff_original < -mcid)   # P(Δ < -MCID)
    p_mcid <- p_mcid_positive            # kept for backward compatibility
  } else if (outcome == "UPDRS_part_3") {
    # Original scale: the threshold applies directly to the coefficient
    p_direction <- p_direction_negative  # kept for backward compatibility
    p_mcid_positive <- mean(treatment_samples > abs(mcid))  # P(Δ > +|MCID|)
    p_mcid_negative <- mean(treatment_samples < mcid)       # P(Δ < MCID)
    p_mcid <- p_mcid_negative            # kept for backward compatibility
  } else {
    stop("Unknown outcome for probability calculation")
  }

  cat("\n=== Posterior Summary ===\n")
  cat(sprintf("Posterior mean (Δ on transformed scale): %.4f\n", posterior_mean))
  cat(sprintf("95%% CrI (transformed scale): [%.4f, %.4f]\n", posterior_ci_lower, posterior_ci_upper))
  if (outcome == "Mean_corr_ON_time") {
    cat(sprintf("P(Δ > 0 on transformed scale): %.4f\n", p_direction_positive))
    cat(sprintf("P(Δ < 0 on transformed scale): %.4f\n", p_direction_negative))
    cat(sprintf("Mean difference on original scale: %.2f hours\n", mean(diff_original)))
    cat(sprintf("95%% CrI on original scale: [%.2f, %.2f] hours\n",
                quantile(diff_original, 0.025), quantile(diff_original, 0.975)))
    cat(sprintf("P(Δ > +%gh on original scale): %.4f\n", mcid, p_mcid_positive))
    cat(sprintf("P(Δ < -%gh on original scale): %.4f\n", mcid, p_mcid_negative))
  } else {
    cat(sprintf("P(Δ > 0): %.4f\n", p_direction_positive))
    cat(sprintf("P(Δ < 0): %.4f\n", p_direction_negative))
    cat(sprintf("P(Δ > +%g): %.4f\n", abs(mcid), p_mcid_positive))
    cat(sprintf("P(Δ < -%g): %.4f\n", abs(mcid), p_mcid_negative))
  }

  result_list <- list(
    outcome = outcome,
    prior_type = prior_type,
    fit = fit_bayes,
    posterior_samples = treatment_samples,
    posterior_mean = posterior_mean,
    posterior_sd = posterior_sd,
    posterior_ci_lower = posterior_ci_lower,
    posterior_ci_upper = posterior_ci_upper,
    p_direction_positive = p_direction_positive,
    p_direction_negative = p_direction_negative,
    p_direction = p_direction,
    p_mcid_positive = p_mcid_positive,
    p_mcid_negative = p_mcid_negative,
    p_mcid = p_mcid,
    data = outcome_data,
    max_rhat = diagnostics$max_rhat,
    min_n_eff = diagnostics$min_n_eff,
    min_tail_ess = diagnostics$min_tail_ess,
    n_divergent = diagnostics$n_divergent
  )

  # ON time: also keep the original-scale (hours) contrast and marginal means
  if (outcome == "Mean_corr_ON_time") {
    result_list$diff_original_samples <- diff_original
    result_list$mean_aDBS_original <- mean(on_time$mean_aDBS)
    result_list$mean_cDBS_original <- mean(on_time$mean_cDBS)
  }

  result_list
}

# =============================================================================
# Run analyses for all combinations
# =============================================================================

cat("\n=== Running Bayesian Analyses ===\n")

# Results are keyed "<outcome>_<prior>". Combinations whose analysis returns
# NULL (too little data or a failed fit) are left out of the list.
bayes_results <- list()

for (outcome in target_outcomes) {
  for (prior_type in c("weak", "literature")) {
    key <- paste0(outcome, "_", prior_type)
    cat(sprintf("\n--- Processing: %s ---\n", key))

    result <- run_bayesian_analysis(outcome, prior_type, ancova_data)
    if (is.null(result)) {
      next
    }
    bayes_results[[key]] <- result
  }
}

# =============================================================================
# Compile results into summary table
# =============================================================================

cat("\n=== Compiling Results Summary ===\n")

# Fail early with a clear message. With no successful fits the summary would be
# a 0-column tibble, and the convergence select() below would fail cryptically.
if (length(bayes_results) == 0) {
  stop("No Bayesian analyses completed successfully; nothing to summarise.",
       call. = FALSE)
}

results_summary <- map_dfr(bayes_results, function(res) {
  result_row <- tibble(
    Outcome = res$outcome,
    Prior = res$prior_type,
    Posterior_Mean_Transformed = res$posterior_mean,
    Posterior_SD_Transformed   = sd(res$posterior_samples),
    CrI_Lower_Transformed = res$posterior_ci_lower,
    CrI_Upper_Transformed = res$posterior_ci_upper,
    # Both directions for probabilities
    P_Direction_Positive = res$p_direction_positive,
    P_Direction_Negative = res$p_direction_negative,
    P_Direction = res$p_direction,  # Keep for backward compatibility
    P_MCID_Positive = res$p_mcid_positive,
    P_MCID_Negative = res$p_mcid_negative,
    P_MCID = res$p_mcid,  # Keep for backward compatibility
    Max_Rhat = res$max_rhat,
    Min_n_eff = res$min_n_eff,
    Min_Tail_ESS = res$min_tail_ess,
    N_Divergent = res$n_divergent
  )

  # Original-scale (hours) columns exist only for ON time. Other rows get NA so
  # all rows share one schema.
  if (res$outcome == "Mean_corr_ON_time" && !is.null(res$diff_original_samples)) {
    result_row <- result_row %>%
      mutate(
        Posterior_Mean_Original = mean(res$diff_original_samples),
        Posterior_SD_Original   = sd(res$diff_original_samples),
        CrI_Lower_Original = quantile(res$diff_original_samples, 0.025),
        CrI_Upper_Original = quantile(res$diff_original_samples, 0.975),
        Mean_aDBS_Original = res$mean_aDBS_original,
        Mean_cDBS_Original = res$mean_cDBS_original
      )
  } else {
    result_row <- result_row %>%
      mutate(
        Posterior_Mean_Original = NA_real_,
        Posterior_SD_Original   = NA_real_,
        CrI_Lower_Original = NA_real_,
        CrI_Upper_Original = NA_real_,
        Mean_aDBS_Original = NA_real_,
        Mean_cDBS_Original = NA_real_
      )
  }

  result_row
})

print(results_summary)

# Output paths are defined once so the written files and the log agree
summary_csv_path     <- "results_bayesian_ancova_summary.csv"
results_rds_path     <- "results_bayesian_ancova.rds"
convergence_csv_path <- "results_bayesian_convergence_diagnostics.csv"

write_csv(results_summary, summary_csv_path)
saveRDS(bayes_results, results_rds_path)

# Convergence diagnostics, saved separately for easy reference
convergence_summary <- results_summary %>%
  select(Outcome, Prior, Max_Rhat, Min_n_eff, Min_Tail_ESS, N_Divergent) %>%
  mutate(
    Rhat_OK = Max_Rhat < 1.01,
    ESS_OK = Min_n_eff > 400 & (is.na(Min_Tail_ESS) | Min_Tail_ESS > 400),
    # TRUE only when the count is known to be 0; NA (unknown) -> FALSE
    Divergent_OK = !is.na(N_Divergent) & N_Divergent == 0
  )

write_csv(convergence_summary, convergence_csv_path)

cat("\nResults saved to:\n")
cat(sprintf("1. %s - Full results with posterior summaries\n", summary_csv_path))
cat(sprintf("2. %s - Complete analysis objects\n", results_rds_path))
cat(sprintf("3. %s - Convergence diagnostics (Rhat, ESS, Divergent transitions)\n", convergence_csv_path))

# =============================================================================
# Create posterior density plots
# =============================================================================

cat("\n=== Creating Posterior Density Plots ===\n")

# Build a posterior density plot of the aDBS - cDBS effect for one result.
#
# ON time is plotted in hours using the back-transformed contrast draws when
# available, otherwise on the log1p scale. UPDRS Part III is plotted on its
# original scale. MCID thresholds are read from `mcid_values`, so they match
# the probabilities computed in run_bayesian_analysis().
#
# Shading: blue = Δ > 0, red = Δ < 0, green = beyond the MCID threshold(s).
# Dashed green lines mark ±MCID for ON time. For UPDRS Part III only the
# improvement threshold (-|MCID|) is marked, although both tails are shaded.
#
# Args:
#   res: one element of `bayes_results` (uses outcome, prior_type,
#        posterior_samples, p_direction, p_mcid, diff_original_samples).
# Returns: a ggplot object.
create_posterior_plot <- function(res) {
  outcome <- res$outcome
  prior_type <- res$prior_type
  mcid <- mcid_values[[outcome]]
  if (is.null(mcid)) {
    stop(sprintf("No MCID defined for outcome '%s'", outcome), call. = FALSE)
  }
  mcid_size <- abs(mcid)

  if (outcome == "Mean_corr_ON_time" && !is.null(res$diff_original_samples)) {
    plot_samples <- res$diff_original_samples
    x_label <- "Treatment Effect (aDBS - cDBS) in hours"
    line_positive <- mcid_size
    line_negative <- -mcid_size
    shade_above <- mcid_size
    shade_below <- -mcid_size
    direction_label <- sprintf("P(Δ > 0) = %.3f, P(Δ < 0) = %.3f",
                               mean(plot_samples > 0), mean(plot_samples < 0))
    mcid_prob_label <- sprintf("P(Δ > +%gh) = %.3f, P(Δ < -%gh) = %.3f",
                               mcid_size, mean(plot_samples > mcid_size),
                               mcid_size, mean(plot_samples < -mcid_size))
  } else if (outcome == "Mean_corr_ON_time") {
    # Fallback: original-scale draws missing, plot the log1p-scale coefficient
    plot_samples <- res$posterior_samples
    x_label <- "Treatment Effect (aDBS - cDBS) on log1p scale"
    line_positive <- NA
    line_negative <- NA
    shade_above <- NA
    shade_below <- NA
    direction_label <- sprintf("P(Δ > 0) = %.3f", res$p_direction)
    mcid_prob_label <- sprintf("P(Δ > MCID) = %.3f (on original scale)", res$p_mcid)
  } else if (outcome == "UPDRS_part_3") {
    plot_samples <- res$posterior_samples
    x_label <- "Treatment Effect (aDBS - cDBS)"
    line_positive <- NA           # only the improvement threshold is drawn
    line_negative <- -mcid_size
    shade_above <- mcid_size
    shade_below <- -mcid_size
    direction_label <- sprintf("P(Δ < 0) = %.3f, P(Δ > 0) = %.3f",
                               mean(plot_samples < 0), mean(plot_samples > 0))
    mcid_prob_label <- sprintf("P(Δ < -%g) = %.3f, P(Δ > +%g) = %.3f",
                               mcid_size, mean(plot_samples < -mcid_size),
                               mcid_size, mean(plot_samples > mcid_size))
  } else {
    stop(sprintf("Unsupported outcome for plotting: '%s'", outcome), call. = FALSE)
  }

  # Kernel density used for the shaded regions
  dens <- density(plot_samples)
  dens_df <- data.frame(x = dens$x, y = dens$y)

  # Logical mask of density points beyond a threshold (all FALSE when NA)
  beyond <- function(threshold, above) {
    if (is.na(threshold)) {
      return(rep(FALSE, nrow(dens_df)))
    }
    if (above) dens_df$x > threshold else dens_df$x < threshold
  }
  # Shaded ribbon under the density for the masked points; NULL when empty
  ribbon_layer <- function(keep, fill, alpha) {
    if (!any(keep)) {
      return(NULL)
    }
    geom_ribbon(data = dens_df[keep, ], aes(x = x, ymin = 0, ymax = y),
                fill = fill, alpha = alpha, inherit.aes = FALSE)
  }
  # Dashed MCID reference line; NULL when no threshold is drawn
  mcid_line <- function(at) {
    if (is.na(at)) {
      return(NULL)
    }
    geom_vline(xintercept = at, linetype = "dashed", color = "darkgreen", linewidth = 0.8)
  }

  outcome_title <- if (outcome == "Mean_corr_ON_time") "ON Time" else "UPDRS Part III"

  p <- ggplot(data.frame(delta = plot_samples), aes(x = delta)) +
    ribbon_layer(dens_df$x > 0, "steelblue", 0.2) +
    ribbon_layer(dens_df$x < 0, "coral", 0.2) +
    ribbon_layer(beyond(shade_above, above = TRUE), "darkgreen", 0.3) +
    ribbon_layer(beyond(shade_below, above = FALSE), "darkgreen", 0.3) +
    geom_density(fill = "gray70", alpha = 0.5, color = "black", linewidth = 0.8) +
    geom_vline(xintercept = 0, linetype = "solid", color = "black", linewidth = 0.8) +
    mcid_line(line_positive) +
    mcid_line(line_negative) +
    labs(
      title = sprintf("Posterior Distribution: %s (%s prior)", outcome_title, prior_type),
      x = x_label,
      y = "Posterior Density",
      caption = sprintf("%s\n%s\nShaded areas: blue = positive, red = negative, green = MCID thresholds",
                        direction_label, mcid_prob_label)
    ) +
    theme_classic(base_size = 11) +
    theme(
      plot.title = element_text(size = 12, face = "bold"),
      plot.caption = element_text(size = 9, color = "gray50", hjust = 0),
      axis.text = element_text(size = 10),
      axis.title = element_text(size = 10)
    )

  p
}

# Save one posterior density PDF per outcome x prior combination
for (key in names(bayes_results)) {
  res <- bayes_results[[key]]
  filename <- paste0("figure_bayesian_posterior_", res$outcome, "_", res$prior_type, ".pdf")
  p <- create_posterior_plot(res)

  ggsave(filename = filename, plot = p, width = 4, height = 3, bg = "white")
  cat("Saved: ", filename, "\n", sep = "")
}

# =============================================================================
# Create stacked plots (same outcome, different priors, aligned x-axis)
# =============================================================================

cat("\n=== Creating Stacked Plots (Aligned X-axis) ===\n")

# Function to create stacked plot for same outcome
create_stacked_plot <- function(outcome, bayes_results) {
  
  # Get both prior results for this outcome
  weak_key <- paste(outcome, "weak", sep = "_")
  lit_key <- paste(outcome, "literature", sep = "_")
  
  if (!(weak_key %in% names(bayes_results) && lit_key %in% names(bayes_results))) {
    cat(sprintf("Warning: Missing results for %s\n", outcome))
    return(NULL)
  }
  
  res_weak <- bayes_results[[weak_key]]
  res_lit <- bayes_results[[lit_key]]
  
  # Determine which scale to use and get samples
  if (outcome == "Mean_corr_ON_time") {
    # Use original scale if available
    if (!is.null(res_weak$diff_original_samples) && !is.null(res_lit$diff_original_samples)) {
      samples_weak <- res_weak$diff_original_samples
      samples_lit <- res_lit$diff_original_samples
      x_label <- "Treatment Effect (aDBS - cDBS) in hours"
      mcid_pos <- 2.0
      mcid_neg <- -2.0
    } else {
      samples_weak <- res_weak$posterior_samples
      samples_lit <- res_lit$posterior_samples
      x_label <- "Treatment Effect (aDBS - cDBS) on log1p scale"
      mcid_pos <- NA
      mcid_neg <- NA
    }
  } else {
    samples_weak <- res_weak$posterior_samples
    samples_lit <- res_lit$posterior_samples
    x_label <- "Treatment Effect (aDBS - cDBS)"
    mcid_pos <- NA
    mcid_neg <- mcid_values[[outcome]]  # -5 for UPDRS3
  }
  
  # Calculate unified x-axis limits
  all_samples <- c(samples_weak, samples_lit)
  x_range <- range(all_samples, na.rm = TRUE)
  x_margin <- diff(x_range) * 0.1
  x_limits <- c(x_range[1] - x_margin, x_range[2] + x_margin)
  
  # Create plot data
  plot_data_weak <- data.frame(delta = samples_weak, prior = "weak")
  plot_data_lit <- data.frame(delta = samples_lit, prior = "literature")
  
  # Calculate densities
  dens_weak <- density(samples_weak)
  dens_lit <- density(samples_lit)
  
  dens_df_weak <- data.frame(x = dens_weak$x, y = dens_weak$y, prior = "weak")
  dens_df_lit <- data.frame(x = dens_lit$x, y = dens_lit$y, prior = "literature")
  
  # Determine shaded regions for weak prior
  if (outcome == "Mean_corr_ON_time" && !is.null(res_weak$diff_original_samples)) {
    dens_df_weak$shade_positive <- dens_df_weak$x > 0
    dens_df_weak$shade_negative <- dens_df_weak$x < 0
    if (!is.na(mcid_pos)) {
      dens_df_weak$shade_mcid_positive <- dens_df_weak$x > mcid_pos
      dens_df_weak$shade_mcid_negative <- dens_df_weak$x < mcid_neg
    } else {
      dens_df_weak$shade_mcid_positive <- FALSE
      dens_df_weak$shade_mcid_negative <- FALSE
    }
  } else if (outcome == "UPDRS_part_3") {
    dens_df_weak$shade_positive <- dens_df_weak$x > 0
    dens_df_weak$shade_negative <- dens_df_weak$x < 0
    dens_df_weak$shade_mcid_negative <- dens_df_weak$x < mcid_neg
    dens_df_weak$shade_mcid_positive <- dens_df_weak$x > abs(mcid_neg)
  }
  
  # Determine shaded regions for literature prior
  if (outcome == "Mean_corr_ON_time" && !is.null(res_lit$diff_original_samples)) {
    dens_df_lit$shade_positive <- dens_df_lit$x > 0
    dens_df_lit$shade_negative <- dens_df_lit$x < 0
    if (!is.na(mcid_pos)) {
      dens_df_lit$shade_mcid_positive <- dens_df_lit$x > mcid_pos
      dens_df_lit$shade_mcid_negative <- dens_df_lit$x < mcid_neg
    } else {
      dens_df_lit$shade_mcid_positive <- FALSE
      dens_df_lit$shade_mcid_negative <- FALSE
    }
  } else if (outcome == "UPDRS_part_3") {
    dens_df_lit$shade_positive <- dens_df_lit$x > 0
    dens_df_lit$shade_negative <- dens_df_lit$x < 0
    dens_df_lit$shade_mcid_negative <- dens_df_lit$x < mcid_neg
    dens_df_lit$shade_mcid_positive <- dens_df_lit$x > abs(mcid_neg)
  }
  
  # Create plot for weak prior
  p_weak <- ggplot(plot_data_weak, aes(x = delta)) +
    geom_ribbon(data = dens_df_weak[dens_df_weak$shade_positive, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "steelblue", alpha = 0.2, inherit.aes = FALSE) +
    geom_ribbon(data = dens_df_weak[dens_df_weak$shade_negative, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "coral", alpha = 0.2, inherit.aes = FALSE) +
    {if(any(dens_df_weak$shade_mcid_positive)) geom_ribbon(data = dens_df_weak[dens_df_weak$shade_mcid_positive, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "darkgreen", alpha = 0.3, inherit.aes = FALSE)} +
    {if(any(dens_df_weak$shade_mcid_negative)) geom_ribbon(data = dens_df_weak[dens_df_weak$shade_mcid_negative, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "darkgreen", alpha = 0.3, inherit.aes = FALSE)} +
    geom_density(fill = "gray70", alpha = 0.5, color = "black", linewidth = 0.8) +
    geom_vline(xintercept = 0, linetype = "solid", color = "black", linewidth = 0.8) +
    {if(!is.na(mcid_pos)) geom_vline(xintercept = mcid_pos, linetype = "dashed", color = "darkgreen", linewidth = 0.8)} +
    {if(!is.na(mcid_neg)) geom_vline(xintercept = mcid_neg, linetype = "dashed", color = "darkgreen", linewidth = 0.8)} +
    scale_x_continuous(limits = x_limits) +
    labs(
      title = sprintf("%s (weak prior)", 
                     ifelse(outcome == "Mean_corr_ON_time", "ON Time", "UPDRS Part III")),
      x = "",
      y = "Posterior Density"
    ) +
    theme_classic(base_size = 11) +
    theme(
      plot.title = element_text(size = 11, face = "bold"),
      axis.text.x = element_blank(),
      axis.ticks.x = element_blank(),
      axis.text = element_text(size = 9),
      axis.title = element_text(size = 9)
    )
  
  # Create plot for literature prior
  p_lit <- ggplot(plot_data_lit, aes(x = delta)) +
    geom_ribbon(data = dens_df_lit[dens_df_lit$shade_positive, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "steelblue", alpha = 0.2, inherit.aes = FALSE) +
    geom_ribbon(data = dens_df_lit[dens_df_lit$shade_negative, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "coral", alpha = 0.2, inherit.aes = FALSE) +
    {if(any(dens_df_lit$shade_mcid_positive)) geom_ribbon(data = dens_df_lit[dens_df_lit$shade_mcid_positive, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "darkgreen", alpha = 0.3, inherit.aes = FALSE)} +
    {if(any(dens_df_lit$shade_mcid_negative)) geom_ribbon(data = dens_df_lit[dens_df_lit$shade_mcid_negative, ], 
                aes(x = x, ymin = 0, ymax = y), 
                fill = "darkgreen", alpha = 0.3, inherit.aes = FALSE)} +
    geom_density(fill = "gray70", alpha = 0.5, color = "black", linewidth = 0.8) +
    geom_vline(xintercept = 0, linetype = "solid", color = "black", linewidth = 0.8) +
    {if(!is.na(mcid_pos)) geom_vline(xintercept = mcid_pos, linetype = "dashed", color = "darkgreen", linewidth = 0.8)} +
    {if(!is.na(mcid_neg)) geom_vline(xintercept = mcid_neg, linetype = "dashed", color = "darkgreen", linewidth = 0.8)} +
    scale_x_continuous(limits = x_limits) +
    labs(
      title = sprintf("%s (literature-informed prior)", 
                     ifelse(outcome == "Mean_corr_ON_time", "ON Time", "UPDRS Part III")),
      x = x_label,
      y = "Posterior Density"
    ) +
    theme_classic(base_size = 11) +
    theme(
      plot.title = element_text(size = 11, face = "bold"),
      axis.text = element_text(size = 9),
      axis.title = element_text(size = 9)
    )
  
  # Stack plots vertically
  p_stacked <- p_weak / p_lit + 
    plot_layout(heights = c(1, 1))
  
  return(p_stacked)
}

# Render and export the stacked (weak prior over literature-informed prior)
# posterior figure for every target outcome. create_stacked_plot() returning
# NULL means the outcome has nothing to plot, so it is skipped.
for (outcome in target_outcomes) {
  stacked_fig <- create_stacked_plot(outcome, bayes_results)
  if (is.null(stacked_fig)) {
    next
  }

  # No directory component and no `path =`, so ggsave() writes relative to
  # the current working directory.
  out_file <- paste0("figure_bayesian_posterior_", outcome, "_stacked.pdf")
  ggsave(
    filename = out_file,
    plot = stacked_fig,
    width = 3,
    height = 2,
    bg = "white"
  )
  cat(sprintf("Saved: %s\n", out_file))
}

# Final status banner.
# NOTE(review): the stacked PDFs above go to the working directory, not a
# "plots" subdirectory. Confirm the second line of this message is accurate.
cat(
  "\n=== Bayesian Analysis Complete ===\n",
  "All posterior density plots saved to plots directory.\n",
  "Individual plots and stacked plots (aligned x-axis) are available.\n",
  sep = ""
)

